%%%%%%%%%%
% Authors : Soumendu Sundar Mukherjee, Purnamrita Sarkar
% License: LGPLv3 (inclusive of all the scripts in the 'helper scripts' folder)
% Description: Reference implementation of PACE and GALE with spectral clustering as the global algorithm
%%%%%%%%%% 

% Note: Keep 'local_patching.m' and 'helper scripts' in the same directory and cd into it 
% Put the helper functions on the MATLAB path
addpath('./helper scripts/');

% Fix the global random stream so that repeated runs draw the same numbers
s = RandStream('mcg16807', 'Seed', 0);
RandStream.setGlobalStream(s);

% Start a parallel pool with nworkers workers unless one is already open
nworkers = 8;
pool = gcp('nocreate');
if(isempty(pool))
    parpool(nworkers);
end

%%%% Various Simulation settings %%%%

%-- Setting 1 --%

% K-by-K matrix B has p on the diagonal and q = r*p everywhere else.
% Note: 'pi' below shadows MATLAB's built-in constant for the rest of the script.
n = 2000;
K = 2;
rho = 0.05;
p = 0.2;
r = 0.01;
q = r*p;
B = q * ones(K) + (p - q) * eye(K);
pi = [.2 .8];
avg_deg = (n - 1) * rho * pi * B * pi';

%-- Setting 2 --%

% n = 2000;
% K = 5;
% rho = 0.05;
% p = 0.3;
% q = 0.03;
% B = (p - q) * eye(K) + q * ones(K);
% % pi = ones(K, 1) / K;
% % pi = [0.15 0.1 0.3 0.35 0.1];
% pi= [0.15 0.25 0.2 0.25 0.15];
% avg_deg = (n - 1) * rho * pi * B * pi';

%-- Setting 3 --%

% n = 60000;
% K = 2;
% rho = 0.0005;
% p = 0.2;
% r = 0.01;
% q = p * r;
% B = (p - q) * eye(K) + q * ones(K);
% pi = [.2 .8];
% avg_deg =  (n - 1) * rho * pi * B * pi';

%-- Setting 4 --%

% n = 100000;
% K = 10;
% rho = 0.001;
% p = 1;
% r = 0.005; 
% q = p * r;
% B = (p - q) * eye(K) + q * ones(K);
% pi = ones(1, K) / K;
% % pi = [.2 .8]; % length-2 proportions: incompatible with K = 10 above, keep disabled
% avg_deg =  (n - 1) * rho * pi * B * pi';

% Print the chosen graph parameters
rule = '----------------------\n';
fprintf('\nGraph parameters\n');
fprintf(rule);
fprintf('n = %d\n', n);
fprintf('K = %d\n', K);
fprintf('Average degree = %.f\n', avg_deg);
fprintf(rule);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% Generate graph %%%%

fprintf('\nGenerating graph...\n')
time1 = tic;
[A, comm] = cbm_parallel(nworkers, n, rho, B, pi);
% Re-derive the sizes from what the generator actually returned
n = size(A, 1);
K = max(comm);
timetaken.graph_generation = toc(time1);
fprintf('Time taken in creating the graph = %.2f seconds\n', timetaken.graph_generation);

% [~, I] = sort(comm);
% spy(A(I, I));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% Global method %%%%
% Baseline: run the clustering routine once on the whole graph and score it
% against the ground-truth labels restricted to the retained vertices.
fprintf('Running global algorithm...\n')
[A1, ~, ~, I1] = process_real_graph(A, A, A, 1); % largest connected component of A
comm1 = comm(I1); % ground-truth labels of the vertices kept in A1
% length(I1)

time2 = tic;
% The 4th argument used to be the char array 'true'; the other two calls to
% spectral in this script pass numeric 1, so the same value is used here.
% NOTE(review): spectral.m is not visible from this file - confirm it expects
% a numeric/logical flag rather than a string.
[~, comm_est] = spectral(A1, K, 'regLaplacian', 1, 'Rohe');
timetaken.global = toc(time2); % times the clustering call only
% fprintf('Time taken by the global algorithm = %f\n', timetaken.global);
error_global = cluster_acc(comm1, comm_est);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% Computation on patches %%%%
% Draw T root vertices, extract an h-hop subgraph around each one, cluster
% every subgraph that is large enough, and record per patch
%   nbsAll(t, v)   = 1 if vertex v belongs to accepted patch t, else 0
%   ClustAll(t, v) = label given to v inside patch t (0 if v is not in it)
fprintf('Performing computations on patches...\n')
h = 3;              % hop radius passed to hhop
m = 800;            % only referenced by the commented-out alternatives
T = 100;            % number of patches requested
mstar = 10;         % a patch is accepted only if it keeps >= mstar vertices
thresh = 1;         % passed through to process_real_graph
min_root_deg = 10;  % only vertices with at least this degree may be roots

c = 0; % accepted subgraphs out of T
avg_subgraph_size = 0;

tstart = tic;

% Select roots of h-hop subgraphs
% roots = randsample(n, T);
deg = sum(A);
hubs = find(deg >= min_root_deg);
if(isempty(hubs))
    error('local_patching:noRoots', ...
        'No vertex has degree >= %d; cannot select patch roots.', min_root_deg);
elseif(length(hubs) < T)
    % randsample without replacement would throw here; use every eligible
    % vertex instead and shrink T so that everything downstream stays consistent.
    warning('local_patching:fewRoots', ...
        'Only %d vertices have degree >= %d; reducing T from %d to %d.', ...
        length(hubs), min_root_deg, T, length(hubs));
    T = length(hubs);
end
roots = randsample(length(hubs), T);
roots = hubs(roots);

nbsAll = sparse(T, n);
ClustAll = sparse(T, n);
err_sub = zeros(T, 1);

% c and avg_subgraph_size are accumulated across iterations (parfor
% reductions); nbsAll, ClustAll and err_sub are written one row per iteration.
parfor t = 1:T

    [As, verts] = hhop(A, roots(t), h);
    % [As, verts] = onion(A, roots(t), h);
    % [As, verts] = randsub(A, m);
    [As, ~, ~, I] = process_real_graph(As, As, As, thresh); % largest connected component
    verts = verts(I);
    % length(I)
    if(length(I) >= mstar)
        c = c+1;
        avg_subgraph_size = avg_subgraph_size + length(verts);
        [~, comm_sub] = spectral(As, K, 'regLaplacian', 1, 'Rohe');
        error_sub = cluster_acc(comm_sub, comm(verts));
        err_sub(t) = error_sub;
        % error_sub
        v1 = zeros(1, n);
        v1(verts) = 1;
        nbsAll(t, :) = v1;
        % nbsAll(t, :) = sparse(verts, ones(1, n));
        v2 = zeros(1, n);
        v2(verts) = comm_sub;
        ClustAll(t, :) = v2;
    end
end

tot_time = toc(tstart);

timetaken.subgraphs_total = tot_time;
timetaken.subgraphs_avg = tot_time/T;

if(c > 0)
    avg_subgraph_size = avg_subgraph_size/c;
else
    % Avoid 0/0 = NaN; avg_subgraph_size stays 0 when nothing was accepted.
    warning('local_patching:noPatches', ...
        'None of the %d patches kept at least mstar = %d vertices.', T, mstar);
end
fprintf('Average subgraph size = %.f\n', avg_subgraph_size);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% PACE %%%%
% Build the thresholded co-clustering matrix C from the per-patch results:
%   N(i, j)     = number of accepted patches containing both i and j
%   Ctemp(i, j) = number of those patches that give i and j the same label
%   C(i, j)     = Ctemp(i, j) / N(i, j) if N(i, j) >= tau, and 0 otherwise
fprintf('Performing PACE specific computations...\n')
tstart1 = tic;

% Rows of nbsAll are 0/1 membership indicators, so the pair counts are one
% sparse product. (N is still dense whenever the patches cover most vertices.)
N = nbsAll' * nbsAll;

% Summing Zk' * Zk over the labels k counts, for every pair, the patches in
% which both vertices received the same label.
Ctemp = sparse(n, n);
for k = 1:K
    Zk = double(ClustAll == k); % T-by-n indicator of "labelled k in patch t"
    Ctemp = Ctemp + Zk' * Zk;
end

perc = 0.4;
theta = 0.8;
% tau = T * m * (m-1) / (n * (n - 1)) * theta;
pair_counts = nonzeros(N);
if(isempty(pair_counts))
    error('local_patching:emptyPatches', ...
        'No vertex pair is covered by any accepted patch; cannot form C.');
end
tau = quantile(pair_counts, perc);

% Pairs seen in fewer than tau patches get weight 0, all others 1/N(i, j).
% The previous "C = N" initialisation left the raw counts (>= 1) in the
% sub-threshold entries instead of zeroing them.
% NOTE(review): vertices whose every pair count is below tau now have an
% all-zero row in C - confirm that spectral/mykmeans1 cope with that.
W = spfun(@(x) (x >= tau) ./ x, N);
C = Ctemp .* W;
% Loop based computation of the above step (slow)
% for i = 1:n
%     for j = 1:n
%         if(N(i,j) >= tau)
%             C(i,j) = Ctemp(i,j)/N(i,j);
%         end
%     end
% end

% Parallel computation
% index_cell = cell(n - 1, 1);
% value_cell = cell(n - 1, 1);
% 
% parfor i = 1:(n - 1)
%     for j = (i + 1):n
%         temp = nbsAll(:, i)' * nbsAll(:, j);
%         temp
%         if(temp >= tau)
%             index_cell{i} = [index_cell{i}; i j];
%             ctemp = confusionmat(ClustAll(:, i), ClustAll(:, j));
% %             ctemp
%             value_cell{i} = [value_cell{i}; (trace(ctemp) - ctemp(1, 1))/temp];
%         end
%     end
% end
% 
% indices = vertcat(index_cell{:});
% values = vertcat(value_cell{:});
% 
% C = sparse(indices(:, 1), indices(:, 2), values, n , n);
% C = C + C' + speye(n);
                

timetaken.PACE_thresholding_C = toc(tstart1);

% Recovering Z from C
tstart2 = tic;

% 'kmeans' clusters a random projection of C; any other value clusters C
% with the spectral routine.
recovery_method = 'spectral';

switch recovery_method
    case 'kmeans'
        proj_dim = 50*floor(log(n));
        C_proj = rproj(C, proj_dim); % random projection
        % C_proj = C; % Don't do random projection
        comm_both.PACE = mykmeans1(C_proj, K);   % kmeans clustering
    otherwise
        [~, comm_both.PACE] = spectral(C, K, 'unregLaplacian', 1, 'Rohe');
end

% Projection based on pca (faster, but does not seem to be very accurate)
% s = 2*floor(log(m));
% C_proj = zeros(n, s);
% 
% for t = 1:T
%     verts = find(nbsAll(t, :));
%     comm_sub = ClustAll(t, verts);
%     m1 = size(verts, 2);
%     Z_sub = sparse(1:m1, comm_sub, ones(m1, 1), m1, K);
%     Projection step
%     C_temp = Z_sub * Z_sub';
%     C_proj(verts, :) = C_proj(verts, :) + proj(C_temp, s, 'pca');
% end
%
% comm_both.PACE = mykmeans1(C_proj, K);

timetaken.PACE_recovering_Z = toc(tstart2);

% Score the PACE labels against the ground truth
err_PACE = cluster_acc(comm_both.PACE, comm);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% GALE %%%%
% Combine the per-patch labels with consistent_wrapper1 and score the result
fprintf('Performing GALE specific computations...\n')
tstart3 = tic;
[comm_both.GALE, ~] = consistent_wrapper1(nbsAll, ClustAll, K, 0);
timetaken.GALE_patching_step = toc(tstart3);

err_GALE = cluster_acc(comm_both.GALE, comm);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Summary table. PACE and GALE are both charged the shared patch time.
time_PACE = timetaken.subgraphs_total + timetaken.PACE_thresholding_C + timetaken.PACE_recovering_Z;
time_GALE = timetaken.subgraphs_total + timetaken.GALE_patching_step;
table_rule = '-------------------------------------------------------------\n';

fprintf('\n-------------------------------------------------------------\n');
fprintf('   Algorithm  |  Misclustering error  |  Time taken (seconds)\n');
fprintf(table_rule);
fprintf('      Global  |         %.4f        |        %.2f\n', error_global, timetaken.global);
fprintf('        PACE  |         %.4f        |        %.2f\n', err_PACE, time_PACE);
fprintf('        GALE  |         %.4f        |        %.2f\n', err_GALE, time_GALE);
fprintf('-------------------------------------------------------------\n\n');
